1   /*
2    * Copyright (c) 2001, 2003, Oracle and/or its affiliates. All rights reserved.
3    * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4    *
5    * This code is free software; you can redistribute it and/or modify it
6    * under the terms of the GNU General Public License version 2 only, as
7    * published by the Free Software Foundation.
8    *
9    * This code is distributed in the hope that it will be useful, but WITHOUT
10   * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11   * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12   * version 2 for more details (a copy is included in the LICENSE file that
13   * accompanied this code).
14   *
15   * You should have received a copy of the GNU General Public License version
16   * 2 along with this work; if not, write to the Free Software Foundation,
17   * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18   *
19   * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20   * or visit www.oracle.com if you need additional information or have any
21   * questions.
22   */
23  
24  /*
25   * @test
26   * @bug 4416068 4478803 4479736
27   * @summary 4273544 JSSE request for function forceV3ClientHello()
28   *          4479736 setEnabledProtocols API does not work correctly
29   *          4478803 Need APIs to determine the protocol versions used in an SSL
30   *                  session
31   *          4701722 protocol mismatch exceptions should be consistent between
32   *                  SSLv3 and TLSv1
33   * @author Ram Marti
34   */
35  
36  import java.io.*;
37  import java.net.*;
38  import java.util.*;
39  import java.security.*;
40  import javax.net.ssl.*;
41  import java.security.cert.*;
42  
public class testEnabledProtocols {

    /*
     * For each of the valid protocol combinations, start a server thread
     * that sets up an SSLServerSocket supporting that protocol. Then run
     * a client thread that attempts to open a connection with all
     * possible protocol combinations.  Verify that we get handshake
     * exceptions correctly. Whenever the connection is established
     * successfully, verify that the negotiated protocol was correct.
     * See results file in this directory for complete results.
     */

    static final String[][] protocolStrings = {
                                {"TLSv1"},
                                {"TLSv1", "SSLv2Hello"},
                                {"TLSv1", "SSLv3"},
                                {"SSLv3", "SSLv2Hello"},
                                {"SSLv3"},
                                {"TLSv1", "SSLv3", "SSLv2Hello"}
                                };

    /*
     * Expected-outcome tables, indexed as [serverIndex][clientIndex]:
     * row i applies when the server enables protocolStrings[i], column j
     * when the client enables protocolStrings[j] (see the constructor).
     */
    static final boolean [][] eXceptionArray = {
        // Do we expect exception?       Protocols supported by the server
        { false, true,  false, true,  true,  true }, // TLSv1
        { false, false, false, true,  true,  false}, // TLSv1,SSLv2Hello
        { false, true,  false, true,  false, true }, // TLSv1,SSLv3
        { true,  true,  false, false, false, false}, // SSLv3, SSLv2Hello
        { true,  true,  false, true,  false, true }, // SSLv3
        { false, false, false, false, false, false } // TLSv1,SSLv3,SSLv2Hello
        };

    // Protocol the handshake is expected to negotiate; the null entries
    // line up with the cells where eXceptionArray expects a failure.
    static final String[][] protocolSelected = {
        // TLSv1
        { "TLSv1",  null,   "TLSv1",  null,   null,     null },

        // TLSv1,SSLv2Hello
        { "TLSv1", "TLSv1", "TLSv1",  null,   null,    "TLSv1"},

        // TLSv1,SSLv3
        { "TLSv1",  null,   "TLSv1",  null,   "SSLv3",  null },

        // SSLv3, SSLv2Hello
        {  null,    null,   "SSLv3", "SSLv3", "SSLv3",  "SSLv3"},

        // SSLv3
        {  null,    null,   "SSLv3",  null,   "SSLv3",  null },

        // TLSv1,SSLv3,SSLv2Hello
        { "TLSv1", "TLSv1", "TLSv1", "SSLv3", "SSLv3", "TLSv1" }

    };

    /*
     * Where do we find the keystores?
     */
    final static String pathToStores = "../../../../etc";
    static String passwd = "passphrase";
    static String keyStoreFile = "keystore";
    static String trustStoreFile = "truststore";

    /*
     * Is the server ready to serve?
     */
    volatile static boolean serverReady = false;

    /*
     * Turn on SSL debugging?
     */
    final static boolean debug = false;

    // use any free port by default
    volatile int serverPort = 0;

    // First failure reported by a client thread; rethrown by the constructor.
    volatile Exception clientException = null;

    public static void main(String[] args) throws Exception {
        String keyFilename =
            System.getProperty("test.src", "./") + "/" + pathToStores +
                "/" + keyStoreFile;
        String trustFilename =
            System.getProperty("test.src", "./") + "/" + pathToStores +
                "/" + trustStoreFile;

        System.setProperty("javax.net.ssl.keyStore", keyFilename);
        System.setProperty("javax.net.ssl.keyStorePassword", passwd);
        System.setProperty("javax.net.ssl.trustStore", trustFilename);
        System.setProperty("javax.net.ssl.trustStorePassword", passwd);

        if (debug) {
            System.setProperty("javax.net.debug", "all");
        }

        new testEnabledProtocols();
    }

    /*
     * Runs the full server x client protocol matrix on one shared listening
     * socket. Throws the first client failure, if any. The listening
     * socket is always closed before returning.
     */
    testEnabledProtocols() throws Exception  {
        /*
         * Start the tests.
         */
        SSLServerSocketFactory sslssf =
            (SSLServerSocketFactory) SSLServerSocketFactory.getDefault();
        SSLServerSocket sslServerSocket =
            (SSLServerSocket) sslssf.createServerSocket(serverPort);
        serverPort = sslServerSocket.getLocalPort();
        // sslServerSocket.setNeedClientAuth(true);

        try {
            for (int i = 0; i < protocolStrings.length; i++) {
                String [] serverProtocols = protocolStrings[i];
                startServer ss = new startServer(serverProtocols,
                    sslServerSocket, protocolStrings.length);
                ss.setDaemon(true);
                ss.start();
                for (int j = 0; j < protocolStrings.length; j++) {
                    String [] clientProtocols = protocolStrings[j];
                    startClient sc = new startClient(
                        clientProtocols, serverProtocols,
                        eXceptionArray[i][j], protocolSelected[i][j]);
                    sc.start();
                    sc.join();
                    if (clientException != null) {
                        ss.requestStop();
                        throw clientException;
                    }
                }
                ss.requestStop();
                System.out.println("Waiting for the server to complete");
                ss.join();
            }
        } finally {
            // Release the listening port. On the failure path this also
            // unblocks a server thread still waiting in accept(); that
            // thread has already been asked to stop and exits quietly.
            closeQuietly(sslServerSocket);
        }
    }

    /*
     * Closes the given resource. A null argument and any IOException from
     * close() are ignored, so cleanup never masks the test's real outcome.
     */
    static void closeQuietly(Closeable c) {
        if (c == null) {
            return;
        }
        try {
            c.close();
        } catch (IOException e) {
            // ignore; nothing useful can be done during cleanup
        }
    }

    /*
     * Server side of one matrix row: enables the given protocols on the
     * shared listening socket and serves up to numExpConns connections.
     */
    class startServer extends Thread  {
        private String[] enabledP = null;
        SSLServerSocket sslServerSocket = null;
        int numExpConns;
        volatile boolean stopRequested = false;

        public startServer(String[] enabledProtocols,
                            SSLServerSocket sslServerSocket,
                            int numExpConns) {
            super("Server Thread");
            serverReady = false;
            enabledP = enabledProtocols;
            this.sslServerSocket = sslServerSocket;
            sslServerSocket.setEnabledProtocols(enabledP);
            this.numExpConns = numExpConns;
        }

        public void requestStop() {
            stopRequested = true;
        }

        public void run() {
            int conns = 0;
            while (!stopRequested) {
                SSLSocket socket = null;
                try {
                    serverReady = true;
                    socket = (SSLSocket)sslServerSocket.accept();
                    conns++;

                    // set ready to false. this is just to make the
                    // client wait and synchronise exception messages
                    serverReady = false;
                    socket.startHandshake();
                    SSLSession session = socket.getSession();
                    session.invalidate();

                    InputStream in = socket.getInputStream();
                    OutputStream out = socket.getOutputStream();
                    out.write(280);
                    in.read();

                    socket.close();
                    // sleep for a while so that the server thread can be
                    // stopped
                    Thread.sleep(30);
                } catch (SSLHandshakeException se) {
                    // ignore it; this is part of the testing
                    // log it for debugging
                    System.out.println("Server SSLHandshakeException:");
                    se.printStackTrace(System.out);
                } catch (java.io.InterruptedIOException ioe) {
                    // must have been interrupted, no harm
                    break;
                } catch (java.lang.InterruptedException ie) {
                    // must have been interrupted, no harm; keep the flag set
                    Thread.currentThread().interrupt();
                    break;
                } catch (Exception e) {
                    if (stopRequested && sslServerSocket.isClosed()) {
                        // The main thread stopped the run and closed the
                        // listening socket to unblock accept(); not an error.
                        break;
                    }
                    System.out.println("Server exception:");
                    e.printStackTrace(System.out);
                    throw new RuntimeException(e);
                } finally {
                    closeQuietly(socket);
                }
                if (conns >= numExpConns) {
                    break;
                }
            }
        }
    }

    private static void showProtocols(String name, String[] protocols) {
        System.out.println("Enabled protocols on the " + name + " are: " + Arrays.asList(protocols));
    }

    /*
     * Client side of one matrix cell. It connects with the given protocols
     * and checks the outcome against the expectation. Any failure is
     * recorded in clientException.
     */
    class startClient extends Thread {
        boolean exceptionExpected = false;
        private String[] enabledP = null;
        private String[] serverP = null; // used to print the result
        private String protocolToUse = null;

        startClient(String[] enabledProtocol,
                    String[] serverP,
                    boolean eXception,
                    String protocol) throws Exception {
            super("Client Thread");
            this.enabledP = enabledProtocol;
            this.serverP = serverP;
            this.exceptionExpected = eXception;
            this.protocolToUse = protocol;
        }

        public void run() {
            SSLSocket sslSocket = null;
            try {
                while (!serverReady) {
                    Thread.sleep(50);
                }
                System.out.flush();
                System.out.println("=== Starting new test run ===");
                showProtocols("server", serverP);
                showProtocols("client", enabledP);

                SSLSocketFactory sslsf =
                    (SSLSocketFactory)SSLSocketFactory.getDefault();
                sslSocket = (SSLSocket)
                    sslsf.createSocket("localhost", serverPort);
                sslSocket.setEnabledProtocols(enabledP);
                sslSocket.startHandshake();

                SSLSession session = sslSocket.getSession();
                session.invalidate();
                String protocolName = session.getProtocol();
                System.out.println("Protocol name after getSession is " +
                    protocolName);

                if (exceptionExpected) {
                    // This combination should have been rejected.
                    System.out.println("** FAILURE ** ");
                    throw new RuntimeException
                        ("expected SSLHandshakeException but handshake " +
                         "succeeded using " + protocolName);
                }
                if (protocolName.equals(protocolToUse)) {
                    System.out.println("** Success **");
                } else {
                    System.out.println("** FAILURE ** ");
                    throw new RuntimeException
                        ("expected protocol " + protocolToUse +
                         " but using " + protocolName);
                }

                InputStream in = sslSocket.getInputStream();
                OutputStream out = sslSocket.getOutputStream();
                in.read();
                out.write(280);

                sslSocket.close();

            } catch (SSLHandshakeException e) {
                if (!exceptionExpected) {
                    System.out.println("Client got UNEXPECTED SSLHandshakeException:");
                    e.printStackTrace(System.out);
                    System.out.println("** FAILURE **");
                    clientException = e;
                } else {
                    System.out.println("Client got expected SSLHandshakeException:");
                    e.printStackTrace(System.out);
                    System.out.println("** Success **");
                }
            } catch (RuntimeException e) {
                clientException = e;
            } catch (Exception e) {
                System.out.println("Client got UNEXPECTED Exception:");
                e.printStackTrace(System.out);
                System.out.println("** FAILURE **");
                clientException = e;
            } finally {
                // Previously the socket leaked on every failure path.
                closeQuietly(sslSocket);
            }
        }
    }

}